{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d006b2ea-9dfe-49c7-88a9-a5a0775185fd",
   "metadata": {},
   "source": [
    "# Additional End of week Exercise - week 2\n",
    "\n",
    "Now use everything you've learned from Week 2 to build a full prototype for the technical question/answerer you built in Week 1 Exercise.\n",
    "\n",
    "This should include a Gradio UI, streaming, use of the system prompt to add expertise, and the ability to switch between models. Bonus points if you can demonstrate use of a tool!\n",
    "\n",
    "If you feel bold, see if you can add audio input so you can talk to it, and have it respond with audio. ChatGPT or Claude can help you, or email me if you have questions.\n",
    "\n",
    "I will publish a full solution here soon - unless someone beats me to it...\n",
    "\n",
    "There are so many commercial applications for this, from a language tutor, to a company onboarding solution, to a companion AI to a course (like this one!) I can't wait to see your results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a07e7793-b8f5-44f4-aded-5562f633271a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports\n",
    "import os\n",
    "import json\n",
    "import sqlite3\n",
    "import requests\n",
    "from datetime import datetime\n",
    "from dotenv import load_dotenv\n",
    "from openai import OpenAI\n",
    "import gradio as gr\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "05327b96",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OpenAI API Key found: sk-proj-...\n",
      "OpenWeather API Key found: c2fcd...\n",
      "ExchangeRate API Key found: ce0f6...\n"
     ]
    }
   ],
   "source": [
    "load_dotenv(override=True)\n",
    "\n",
    "openai_api_key = os.getenv('OPENAI_API_KEY')\n",
    "openweather_api_key = os.getenv('OPENWEATHER_API_KEY')\n",
    "exchangerate_api_key = os.getenv('EXCHANGERATE_API_KEY')\n",
    "\n",
    "if openai_api_key:\n",
    "    print(f\"OpenAI API Key found: {openai_api_key[:8]}...\")\n",
    "else:\n",
    "    print(\"OpenAI API Key not set\")\n",
    "\n",
    "if openweather_api_key:\n",
    "    print(f\"OpenWeather API Key found: {openweather_api_key[:5]}...\")\n",
    "else:\n",
    "    print(\"OpenWeather API Key not set - Get one at https://openweathermap.org/api\")\n",
    "\n",
    "if exchangerate_api_key:\n",
    "    print(f\"ExchangeRate API Key found: {exchangerate_api_key[:5]}...\")\n",
    "else:\n",
    "    print(\"ExchangeRate API Key not set - Get one at https://www.exchangerate-api.com/\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "592c421e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ollama is running - Llama model available\n"
     ]
    }
   ],
   "source": [
    "openai = OpenAI()\n",
    "\n",
    "ollama_url = \"http://localhost:11434/v1\"\n",
    "ollama = OpenAI(api_key=\"ollama\", base_url=ollama_url)\n",
    "\n",
    "# Test if Ollama is available\n",
    "ollama_available = False\n",
    "test_response = requests.get(\"http://localhost:11434/\", timeout=2)\n",
    "\n",
    "if test_response.status_code == 200:\n",
    "    ollama_available = True\n",
    "    print(\"Ollama is running - Llama model available\")\n",
    "else:\n",
    "    print(\"Ollama is not responding - Only GPT will be available\")\n",
    "\n",
    "\n",
    "DB = \"bookings.db\"\n",
    "\n",
    "MODELS = {\n",
    "    \"GPT\": \"gpt-4.1-mini\",\n",
    "    \"Llama\": \"llama3.2\"\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "13e11560",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Database created successfully\n",
      "Added 10 sample ticket prices\n"
     ]
    }
   ],
   "source": [
    "# Create database and tables\n",
    "def setup_database():\n",
    "    with sqlite3.connect(DB) as conn:\n",
    "        cursor = conn.cursor()\n",
    "        \n",
    "        # Create prices table\n",
    "        cursor.execute('''\n",
    "            CREATE TABLE IF NOT EXISTS prices (\n",
    "                city TEXT PRIMARY KEY,\n",
    "                price REAL\n",
    "            )\n",
    "        ''')\n",
    "        \n",
    "        # Create bookings table\n",
    "        cursor.execute('''\n",
    "            CREATE TABLE IF NOT EXISTS bookings (\n",
    "                id INTEGER PRIMARY KEY AUTOINCREMENT,\n",
    "                city TEXT,\n",
    "                passenger_name TEXT,\n",
    "                travel_date TEXT,\n",
    "                booking_date TEXT,\n",
    "                status TEXT\n",
    "            )\n",
    "        ''')\n",
    "        \n",
    "        conn.commit()\n",
    "    print(\"Database created successfully\")\n",
    "\n",
    "# Populate sample ticket prices\n",
    "def populate_sample_data():\n",
    "    ticket_prices = {\n",
    "        \"london\": 799,\n",
    "        \"paris\": 899,\n",
    "        \"tokyo\": 1420,\n",
    "        \"sydney\": 2999,\n",
    "        \"berlin\": 499,\n",
    "        \"rome\": 650,\n",
    "        \"new york\": 450,\n",
    "        \"dubai\": 1200,\n",
    "        \"singapore\": 1350,\n",
    "        \"barcelona\": 720\n",
    "    }\n",
    "    \n",
    "    with sqlite3.connect(DB) as conn:\n",
    "        cursor = conn.cursor()\n",
    "        for city, price in ticket_prices.items():\n",
    "            cursor.execute(\n",
    "                'INSERT OR REPLACE INTO prices (city, price) VALUES (?, ?)',\n",
    "                (city.lower(), price)\n",
    "            )\n",
    "        conn.commit()\n",
    "    print(f\"Added {len(ticket_prices)} sample ticket prices\")\n",
    "\n",
    "# Run setup\n",
    "setup_database()\n",
    "populate_sample_data()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "e6328f43",
   "metadata": {},
   "outputs": [],
   "source": [
    "system_message = \"\"\"\n",
    "You are a helpful assistant for FlightAI Airlines.\n",
    "You can help customers with:\n",
    "- Checking ticket prices\n",
    "- Booking flights\n",
    "- Checking booking status\n",
    "- Getting destination information\n",
    "- Checking weather forecasts\n",
    "- Converting currency\n",
    "\n",
    "Always be courteous and professional.\n",
    "If you need to use a tool, use it to provide accurate information.\n",
    "When booking tickets, always confirm the details with the customer.\n",
    "Your first response should be a greeting and a brief introduction of what you can do.\n",
    "\"\"\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3e9637f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_ticket_price(city):\n",
    "    \"\"\"Get the price of a ticket to the specified city\"\"\"\n",
    "    print(f\"Tool called: get_ticket_price({city})\")\n",
    "    with sqlite3.connect(DB) as conn:\n",
    "        cursor = conn.cursor()\n",
    "        cursor.execute('SELECT price FROM prices WHERE city = ?', (city.lower(),))\n",
    "        result = cursor.fetchone()\n",
    "        if result:\n",
    "            return f\"A ticket to {city.title()} costs ${result[0]}\"\n",
    "        else:\n",
    "            return f\"Sorry, we don't have pricing information for {city.title()} at this time.\"\n",
    "\n",
    "def book_ticket(city, passenger_name, travel_date):\n",
    "    \"\"\"Book a ticket for a passenger\"\"\"\n",
    "    print(f\"Tool called: book_ticket({city}, {passenger_name}, {travel_date})\")\n",
    "    \n",
    "    with sqlite3.connect(DB) as conn:\n",
    "        cursor = conn.cursor()\n",
    "        cursor.execute('SELECT price FROM prices WHERE city = ?', (city.lower(),))\n",
    "        price_result = cursor.fetchone()\n",
    "        \n",
    "        if not price_result:\n",
    "            return f\"Sorry, we don't fly to {city.title()} at this time.\"\n",
    "        \n",
    "        # Create booking\n",
    "        booking_date = datetime.now().strftime('%Y-%m-%d')\n",
    "        cursor.execute(\n",
    "            'INSERT INTO bookings (city, passenger_name, travel_date, booking_date, status) VALUES (?, ?, ?, ?, ?)',\n",
    "            (city.lower(), passenger_name, travel_date, booking_date, 'confirmed')\n",
    "        )\n",
    "        booking_id = cursor.lastrowid\n",
    "        conn.commit()\n",
    "        \n",
    "        return f\"Booking confirmed! Booking ID: {booking_id}. Passenger: {passenger_name}, Destination: {city.title()}, Travel Date: {travel_date}, Price: ${price_result[0]}\"\n",
    "\n",
    "def get_booking_status(booking_id):\n",
    "    \"\"\"Check the status of a booking\"\"\"\n",
    "    print(f\"Tool called: get_booking_status({booking_id})\")\n",
    "    with sqlite3.connect(DB) as conn:\n",
    "        cursor = conn.cursor()\n",
    "        cursor.execute(\n",
    "            'SELECT id, city, passenger_name, travel_date, booking_date, status FROM bookings WHERE id = ?',\n",
    "            (booking_id,)\n",
    "        )\n",
    "        result = cursor.fetchone()\n",
    "        \n",
    "        if result:\n",
    "            return f\"Booking #{result[0]} - Passenger: {result[2]}, Destination: {result[1].title()}, Travel Date: {result[3]}, Status: {result[5].upper()}\"\n",
    "        else:\n",
    "            return f\"No booking found with ID {booking_id}\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "19bd5906",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_destination_info(city):\n",
    "    \"\"\"Get information about a destination using REST Countries API\"\"\"\n",
    "    print(f\"Tool called: get_destination_info({city})\")\n",
    "    \n",
    "    # Map common cities to countries\n",
    "    city_to_country = {\n",
    "        \"london\": \"United Kingdom\",\n",
    "        \"paris\": \"France\",\n",
    "        \"tokyo\": \"Japan\",\n",
    "        \"sydney\": \"Australia\",\n",
    "        \"berlin\": \"Germany\",\n",
    "        \"rome\": \"Italy\",\n",
    "        \"new york\": \"United States\",\n",
    "        \"dubai\": \"United Arab Emirates\",\n",
    "        \"singapore\": \"Singapore\",\n",
    "        \"barcelona\": \"Spain\"\n",
    "    }\n",
    "    \n",
    "    country = city_to_country.get(city.lower())\n",
    "    if not country:\n",
    "        return f\"Sorry, I don't have detailed information about {city.title()} at this time.\"\n",
    "    \n",
    "    try:\n",
    "        response = requests.get(f\"https://restcountries.com/v3.1/name/{country}\", timeout=5)\n",
    "        if response.status_code == 200:\n",
    "            data = response.json()[0]\n",
    "            name = data.get('name', {}).get('common', country)\n",
    "            capital = data.get('capital', ['N/A'])[0]\n",
    "            region = data.get('region', 'N/A')\n",
    "            languages = ', '.join(data.get('languages', {}).values())\n",
    "            currency_info = data.get('currencies', {})\n",
    "            currency = list(currency_info.keys())[0] if currency_info else 'N/A'\n",
    "            currency_name = currency_info[currency].get('name', 'N/A') if currency_info else 'N/A'\n",
    "            timezone = data.get('timezones', ['N/A'])[0]\n",
    "            \n",
    "            return f\"\"\"{city.title()} is in {name}. \n",
    "                        Capital: {capital}\n",
    "                        Region: {region}\n",
    "                        Languages: {languages}\n",
    "                        Currency: {currency_name} ({currency})\n",
    "                        Timezone: {timezone}\"\"\"\n",
    "        else:\n",
    "            return f\"Unable to retrieve information about {city.title()}\"\n",
    "    except Exception as e:\n",
    "        return f\"Error fetching destination info: {str(e)}\"\n",
    "\n",
    "def get_weather_forecast(city):\n",
    "    \"\"\"Get weather forecast for a destination using OpenWeatherMap API\"\"\"\n",
    "    print(f\"🔧 Tool called: get_weather_forecast({city})\")\n",
    "    \n",
    "    if not openweather_api_key:\n",
    "        return \"Weather service unavailable. Please set OPENWEATHER_API_KEY in your .env file.\"\n",
    "    \n",
    "    try:\n",
    "        # Get coordinates first\n",
    "        geo_url = f\"http://api.openweathermap.org/geo/1.0/direct?q={city}&limit=1&appid={openweather_api_key}\"\n",
    "        geo_response = requests.get(geo_url, timeout=5)\n",
    "        \n",
    "        if geo_response.status_code == 200 and geo_response.json():\n",
    "            geo_data = geo_response.json()[0]\n",
    "            lat, lon = geo_data['lat'], geo_data['lon']\n",
    "            \n",
    "            # Get weather forecast\n",
    "            weather_url = f\"http://api.openweathermap.org/data/2.5/forecast?lat={lat}&lon={lon}&appid={openweather_api_key}&units=metric\"\n",
    "            weather_response = requests.get(weather_url, timeout=5)\n",
    "            \n",
    "            if weather_response.status_code == 200:\n",
    "                weather_data = weather_response.json()\n",
    "                forecasts = weather_data['list'][:5]  # Next 5 forecasts (15 hours)\n",
    "                \n",
    "                forecast_text = f\"Weather forecast for {city.title()}:\\n\"\n",
    "                for forecast in forecasts:\n",
    "                    time = forecast['dt_txt']\n",
    "                    temp = forecast['main']['temp']\n",
    "                    description = forecast['weather'][0]['description']\n",
    "                    forecast_text += f\"\\n{time}: {temp}°C, {description}\"\n",
    "                \n",
    "                return forecast_text\n",
    "        \n",
    "        return f\"Unable to retrieve weather forecast for {city.title()}\"\n",
    "    except Exception as e:\n",
    "        return f\"Error fetching weather: {str(e)}\"\n",
    "\n",
    "def convert_currency(amount, from_currency, to_currency):\n",
    "    \"\"\"Convert currency using ExchangeRate API\"\"\"\n",
    "    print(f\"Tool called: convert_currency({amount} {from_currency} to {to_currency})\")\n",
    "    \n",
    "    if not exchangerate_api_key:\n",
    "        return \"Currency conversion unavailable. Please set EXCHANGERATE_API_KEY in your .env file.\"\n",
    "    \n",
    "    try:\n",
    "        url = f\"https://v6.exchangerate-api.com/v6/{exchangerate_api_key}/pair/{from_currency.upper()}/{to_currency.upper()}/{amount}\"\n",
    "        response = requests.get(url, timeout=5)\n",
    "        \n",
    "        if response.status_code == 200:\n",
    "            data = response.json()\n",
    "            if data.get('result') == 'success':\n",
    "                converted = data['conversion_result']\n",
    "                rate = data['conversion_rate']\n",
    "                return f\"{amount} {from_currency.upper()} = {converted:.2f} {to_currency.upper()} (Rate: {rate:.4f})\"\n",
    "        \n",
    "        return \"Unable to convert currency. Please check the currency codes.\"\n",
    "    except Exception as e:\n",
    "        return f\"Error converting currency: {str(e)}\"\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ae8a1169",
   "metadata": {},
   "source": [
    "Tools Calling For OPEN AI function call"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9f1b6137",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Defined 6 tools for the assistant\n"
     ]
    }
   ],
   "source": [
    "tools = [\n",
    "    {\n",
    "        \"type\": \"function\",\n",
    "        \"function\": {\n",
    "            \"name\": \"get_ticket_price\",\n",
    "            \"description\": \"Get the price of a return ticket to a destination city\",\n",
    "            \"parameters\": {\n",
    "                \"type\": \"object\",\n",
    "                \"properties\": {\n",
    "                    \"city\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"The destination city\"\n",
    "                    }\n",
    "                },\n",
    "                \"required\": [\"city\"],\n",
    "                \"additionalProperties\": False\n",
    "            }\n",
    "        }\n",
    "    },\n",
    "    {\n",
    "        \"type\": \"function\",\n",
    "        \"function\": {\n",
    "            \"name\": \"book_ticket\",\n",
    "            \"description\": \"Book a flight ticket for a passenger\",\n",
    "            \"parameters\": {\n",
    "                \"type\": \"object\",\n",
    "                \"properties\": {\n",
    "                    \"city\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"The destination city\"\n",
    "                    },\n",
    "                    \"passenger_name\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"The full name of the passenger\"\n",
    "                    },\n",
    "                    \"travel_date\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"The travel date in YYYY-MM-DD format\"\n",
    "                    }\n",
    "                },\n",
    "                \"required\": [\"city\", \"passenger_name\", \"travel_date\"],\n",
    "                \"additionalProperties\": False\n",
    "            }\n",
    "        }\n",
    "    },\n",
    "    {\n",
    "        \"type\": \"function\",\n",
    "        \"function\": {\n",
    "            \"name\": \"get_booking_status\",\n",
    "            \"description\": \"Check the status of a booking using the booking ID\",\n",
    "            \"parameters\": {\n",
    "                \"type\": \"object\",\n",
    "                \"properties\": {\n",
    "                    \"booking_id\": {\n",
    "                        \"type\": \"integer\",\n",
    "                        \"description\": \"The booking ID number\"\n",
    "                    }\n",
    "                },\n",
    "                \"required\": [\"booking_id\"],\n",
    "                \"additionalProperties\": False\n",
    "            }\n",
    "        }\n",
    "    },\n",
    "    {\n",
    "        \"type\": \"function\",\n",
    "        \"function\": {\n",
    "            \"name\": \"get_destination_info\",\n",
    "            \"description\": \"Get detailed information about a destination city including country, capital, language, currency, and timezone\",\n",
    "            \"parameters\": {\n",
    "                \"type\": \"object\",\n",
    "                \"properties\": {\n",
    "                    \"city\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"The destination city\"\n",
    "                    }\n",
    "                },\n",
    "                \"required\": [\"city\"],\n",
    "                \"additionalProperties\": False\n",
    "            }\n",
    "        }\n",
    "    },\n",
    "    {\n",
    "        \"type\": \"function\",\n",
    "        \"function\": {\n",
    "            \"name\": \"get_weather_forecast\",\n",
    "            \"description\": \"Get the weather forecast for a destination city\",\n",
    "            \"parameters\": {\n",
    "                \"type\": \"object\",\n",
    "                \"properties\": {\n",
    "                    \"city\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"The destination city\"\n",
    "                    }\n",
    "                },\n",
    "                \"required\": [\"city\"],\n",
    "                \"additionalProperties\": False\n",
    "            }\n",
    "        }\n",
    "    },\n",
    "    {\n",
    "        \"type\": \"function\",\n",
    "        \"function\": {\n",
    "            \"name\": \"convert_currency\",\n",
    "            \"description\": \"Convert an amount from one currency to another\",\n",
    "            \"parameters\": {\n",
    "                \"type\": \"object\",\n",
    "                \"properties\": {\n",
    "                    \"amount\": {\n",
    "                        \"type\": \"number\",\n",
    "                        \"description\": \"The amount to convert\"\n",
    "                    },\n",
    "                    \"from_currency\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"The source currency code (e.g., USD, EUR, GBP)\"\n",
    "                    },\n",
    "                    \"to_currency\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"The target currency code (e.g., USD, EUR, GBP)\"\n",
    "                    }\n",
    "                },\n",
    "                \"required\": [\"amount\", \"from_currency\", \"to_currency\"],\n",
    "                \"additionalProperties\": False\n",
    "            }\n",
    "        }\n",
    "    }\n",
    "]\n",
    "\n",
    "print(f\"Defined {len(tools)} tools for the assistant\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a5cb2a4",
   "metadata": {},
   "source": [
    "Tool Handler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "15784260",
   "metadata": {},
   "outputs": [],
   "source": [
    "def handle_tool_calls(message):\n",
    "    \"\"\"Process all tool calls in a message and return responses\"\"\"\n",
    "    responses = []\n",
    "    \n",
    "    for tool_call in message.tool_calls:\n",
    "        function_name = tool_call.function.name\n",
    "        arguments = json.loads(tool_call.function.arguments)\n",
    "        \n",
    "        # Call the appropriate function\n",
    "        if function_name == \"get_ticket_price\":\n",
    "            result = get_ticket_price(arguments.get('city'))\n",
    "        elif function_name == \"book_ticket\":\n",
    "            result = book_ticket(\n",
    "                arguments.get('city'),\n",
    "                arguments.get('passenger_name'),\n",
    "                arguments.get('travel_date')\n",
    "            )\n",
    "        elif function_name == \"get_booking_status\":\n",
    "            result = get_booking_status(arguments.get('booking_id'))\n",
    "        elif function_name == \"get_destination_info\":\n",
    "            result = get_destination_info(arguments.get('city'))\n",
    "        elif function_name == \"get_weather_forecast\":\n",
    "            result = get_weather_forecast(arguments.get('city'))\n",
    "        elif function_name == \"convert_currency\":\n",
    "            result = convert_currency(\n",
    "                arguments.get('amount'),\n",
    "                arguments.get('from_currency'),\n",
    "                arguments.get('to_currency')\n",
    "            )\n",
    "        else:\n",
    "            result = f\"Unknown function: {function_name}\"\n",
    "        \n",
    "        responses.append({\n",
    "            \"role\": \"tool\",\n",
    "            \"content\": result,\n",
    "            \"tool_call_id\": tool_call.id\n",
    "        })\n",
    "    \n",
    "    return responses\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8394a6c0",
   "metadata": {},
   "source": [
    "Chat Function wuth streaming"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "e38c6f8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def chat(message, history, model_choice):\n",
    "    \"\"\"Main chat function with streaming and tool support\"\"\"\n",
    "    \n",
    "    # Select the appropriate client\n",
    "    if model_choice == \"Llama\" and not ollama_available:\n",
    "        yield \"Llama is not available. Please start Ollama with 'ollama serve' or select GPT.\"\n",
    "        return\n",
    "    \n",
    "    client = ollama if model_choice == \"Llama\" else openai\n",
    "    model = MODELS[model_choice]\n",
    "    \n",
    "    history = [{\"role\": h[\"role\"], \"content\": h[\"content\"]} for h in history]\n",
    "    messages = [{\"role\": \"system\", \"content\": system_message}] + history + [{\"role\": \"user\", \"content\": message}]\n",
    "    \n",
    "    response = client.chat.completions.create(\n",
    "        model=model,\n",
    "        messages=messages,\n",
    "        tools=tools\n",
    "    )\n",
    "    \n",
    "    # Handle tool calls in a loop\n",
    "    while response.choices[0].finish_reason == \"tool_calls\":\n",
    "        assistant_message = response.choices[0].message\n",
    "        tool_responses = handle_tool_calls(assistant_message)\n",
    "        \n",
    "        messages.append(assistant_message)\n",
    "        messages.extend(tool_responses)\n",
    "        \n",
    "        # Get next response\n",
    "        response = client.chat.completions.create(\n",
    "            model=model,\n",
    "            messages=messages,\n",
    "            tools=tools\n",
    "        )\n",
    "    \n",
    "    # Get final response with streaming\n",
    "    result = response.choices[0].message.content or \"\"\n",
    "    \n",
    "    # Stream the result\n",
    "    stream = client.chat.completions.create(\n",
    "        model=model,\n",
    "        messages=messages + [{\"role\": \"assistant\", \"content\": result}][:-1],\n",
    "        stream=True\n",
    "    )\n",
    "    \n",
    "    streamed_response = \"\"\n",
    "    for chunk in stream:\n",
    "        streamed_response += chunk.choices[0].delta.content or ''\n",
    "        yield streamed_response\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "06d68b99",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define this variable and then pass js=force_dark_mode when creating the Interface\n",
    "\n",
    "force_dark_mode = \"\"\"\n",
    "function refresh() {\n",
    "    const url = new URL(window.location);\n",
    "    if (url.searchParams.get('__theme') !== 'dark') {\n",
    "        url.searchParams.set('__theme', 'dark');\n",
    "        window.location.href = url.href;\n",
    "    }\n",
    "}\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87ed20e6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "* Running on local URL:  http://127.0.0.1:7879\n",
      "* To create a public link, set `share=True` in `launch()`.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div><iframe src=\"http://127.0.0.1:7879/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": []
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tool called: get_ticket_price(London)\n",
      "Tool called: get_destination_info(London)\n",
      "Tool called: convert_currency(799 USD to NGN)\n"
     ]
    }
   ],
   "source": [
    "# Create the Gradio interface\n",
    "with gr.Blocks(title=\"FlightAI Assistant\", js=force_dark_mode) as demo:\n",
    "    gr.Markdown(\"\"\"\n",
    "    # ✈️ FlightAI Assistant\n",
    "    \n",
    "    Welcome to FlightAI! I can help you with:\n",
    "    Checking ticket prices, Booking flights, Checking booking status, Getting destination information\n",
    "    Checking weather forecasts, Converting currencies\n",
    "\n",
    "    \"\"\")\n",
    "    \n",
    "    with gr.Row():\n",
    "        model_selector = gr.Dropdown(\n",
    "            choices=[\"GPT\", \"Llama\"],\n",
    "            value=\"GPT\",\n",
    "            label=\"Select Model\",\n",
    "            info=\"Choose between OpenAI GPT or local Llama (via Ollama)\"\n",
    "        )\n",
    "    \n",
    "    chatbot = gr.ChatInterface(\n",
    "        fn=chat,\n",
    "        additional_inputs=[model_selector],\n",
    "        type=\"messages\",\n",
    "        examples=[\n",
    "            [\"What's the price to London?\", \"GPT\"],\n",
    "            [\"Tell me about Paris\", \"GPT\"],\n",
    "            [\"What's the weather like in Tokyo?\", \"Llama\"],\n",
    "            [\"Convert 799 USD to EUR\", \"GPT\"],\n",
    "            [\"Book me a ticket to Berlin for John Smith on 2024-12-15\", \"GPT\"],\n",
    "            [\"Check booking status for booking ID 1\", \"Llama\"]\n",
    "        ],\n",
    "        title=None,\n",
    "        description=None\n",
    "    )\n",
    "\n",
    "demo.launch()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
